import gymnasium as gym
from gymnasium import spaces
import numpy as np
import random
import matplotlib.pyplot as plt
import os


# Target snake length; not referenced by SnakeEnv in this module.
SNAKE_LEN_GOAL: int = 30

class SnakeEnv(gym.Env):
    """Grid Snake game exposed through the Gymnasium ``Env`` API.

    The board is ``board_size`` x ``board_size``. Each step moves the snake one
    cell (see ``self.directions``). Eating the apple grows the snake and spawns
    a new apple on a free cell. Hitting a wall or the snake's own body ends the
    episode, as does running out of moves (``max_moves`` per apple). Filling
    the whole board also ends it.

    Observation (flat float32 vector built by ``flatten_obs``):
        * ``board_size**2`` board cells (0 empty, 1 head, 2-5 body segment,
          6-9 tail, 10 apple), divided by 10.
        * 2 values: the per-axis minimum signed head-minus-apple offset tracked
          by the reward shaping (``mins_neg``), shifted by ``board_size`` and
          divided by ``2 * board_size``.
        * 1 value: remaining moves divided by ``max_moves``.
        * 4 values: 1 (safe) / 0 (unsafe) flags for the neighbouring cells,
          ordered (x-1, x+1, y-1, y+1).

    Args:
        render: If True, draw every reset/step with matplotlib.
        snake: Optional array of ``(x, y)`` rows, head first, used as the
            starting snake on every reset instead of a random one.
        board_size: Side length of the square board.
    """

    # Directory where generated Hamiltonian paths are cached as "{n}x{n}.npy".
    HAM_CACHE_DIR = "/home/nebul/custom/npcache"

    def __init__(self, render=False, snake=None, board_size=6):
        super().__init__()
        self.should_render = render
        # Move budget per apple; running out ends the episode with a penalty.
        self.max_moves = int(board_size**2 + 5)
        # Apples eaten within this many moves earn a speed bonus.
        self.max_apple_bonus_moves = board_size * 2
        # Shaping reward for improving the best per-axis distance to the apple.
        self.mov_closer_rew = 1 / self.max_apple_bonus_moves * 3
        self.action_space = spaces.Discrete(4)
        # NOTE(review): the declared shape is (1, N) but flatten_obs returns a
        # 1-D (N,) vector, so observations are not members of this Box. Left
        # unchanged because altering either side changes the shape that
        # existing policies/callers see -- reconcile deliberately.
        self.observation_space = spaces.Box(low=0, high=1, shape=(1, board_size**2 + 2 + 1 + 4), dtype=np.float32)
        if self.should_render:
            self._init_render()
        self.done = False
        self.board_size = board_size
        self.max_len = board_size * board_size
        self.ham_arr = self.gen_ham_array()
        # Encodes a unit (dx, dy) step as 2*dx + 3*dy (see get_snake_body).
        self.norm_mul = np.array([[2], [3]])
        self.img = np.zeros((board_size, board_size), dtype=np.uint8)
        self.custom_snake = snake
        self.snake_body = None
        self.body_positions_set = None
        self.remaining_moves = np.array([self.max_moves], dtype=np.int32)
        # Action index -> (dx, dy). Image rows grow with y, so +y is "down".
        self.directions = [
            np.array([-1, 0], dtype=np.int8),  # 0: Left
            np.array([1, 0], dtype=np.int8),   # 1: Right
            np.array([0, 1], dtype=np.int8),   # 2: Down
            np.array([0, -1], dtype=np.int8),  # 3: Up
        ]

    def _init_render(self):
        """Create the interactive matplotlib figure used by ``render``."""
        plt.ion()
        self.fig, self.ax = plt.subplots()
        self.full_color = np.array(255, dtype=np.uint8)

    def step(self, action):
        """Advance the game by one move.

        Args:
            action: Index into ``self.directions`` (0 left, 1 right, 2 +y, 3 -y).

        Returns:
            ``(observation, reward, terminated, truncated, info)``. ``truncated``
            is always False and ``info`` is an empty dict.
        """
        self.reward = 0
        self.snake_head += self.directions[action]
        col_with_apple = False
        if self.snake_head[0] == self.apple_position[0] and self.snake_head[1] == self.apple_position[1]:
            # Grow: prepend the new head and keep the old tail in place.
            self.snake_position = np.append([self.snake_head], self.snake_position, axis=0)
            if len(self.snake_position) == self.max_len:
                # Board is full: no free cell remains for another apple.
                self.done = True
            else:
                self.apple_position, self.score = self.collision_with_apple(
                    self.snake_head, self.score
                )
            col_with_apple = True
        else:
            # Shift every segment forward one slot; the old tail cell is vacated.
            self.snake_position[1:] = self.snake_position[:-1]
            self.snake_position[0] = self.snake_head.copy()

        self.img.fill(0)
        if not self.done:
            self.img[self.apple_position[1], self.apple_position[0]] = 10

        if (
            self.collision_with_boundaries(self.snake_head) == 1
            or self.collision_with_self(self.snake_position, self.snake_head) == 1
        ):
            self.done = True

        # After a collision the head may lie off the board, so it is not drawn.
        if not self.done or col_with_apple:
            self.draw_snake(self.img, self.snake_position)
        else:
            self.draw_snake(self.img, self.snake_position, done=True)

        dists_neg = self.snake_head - self.apple_position
        dists = np.abs(dists_neg)
        # Distance from the previous head position (now the neck) to the apple.
        past_dists = np.abs(self.snake_position[1] - self.apple_position)
        if self.done and not col_with_apple:
            self.reward = -5
        elif col_with_apple:
            if self.mins[0] != 0 or self.mins[1] != 0:
                self.reward = self.mov_closer_rew
            self.reward += 2
            # Restart distance tracking against the newly placed apple.
            self.mins_neg = self.snake_head - self.apple_position
            self.mins = np.abs(self.mins_neg)
        elif dists[0] < self.mins[0]:
            self.reward = self.mov_closer_rew
            self.mins_neg[0] = dists_neg[0]
            self.mins[0] = dists[0]
        elif dists[1] < self.mins[1]:
            self.reward = self.mov_closer_rew
            self.mins_neg[1] = dists_neg[1]
            self.mins[1] = dists[1]
        elif dists[0] < past_dists[0] or dists[1] < past_dists[1]:
            # Closer than last step, but not a new best on either axis.
            self.reward = 0.001
        else:
            self.reward = -0.002

        self.remaining_moves[0] -= 1
        if col_with_apple:
            # Checked before the timeout so that eating on the last allowed
            # move refills the budget instead of starving the snake.
            moves_used = self.max_moves - self.remaining_moves[0]
            if moves_used <= self.max_apple_bonus_moves:
                # Speed bonus decays linearly to 0 at max_apple_bonus_moves.
                self.reward += 4 - moves_used / self.max_apple_bonus_moves * 4
            self.remaining_moves[0] = self.max_moves
        elif not self.done and self.remaining_moves[0] == 0:
            # Out of moves before reaching the apple.
            self.done = True
            self.reward = -5

        if self.should_render:
            self.render()
        self.observation = self.flatten_obs(self.img, self.mins_neg, self.remaining_moves, self.safe_moves_array(self.snake_head, self.snake_position))
        info = {}
        return (
            self.observation,
            self.reward,
            self.done,
            False,
            info,
        )

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed for ``self.np_random``, which drives both the
                random starting position and apple placement.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty ``info`` dict.
        """
        # Seeds self.np_random when a seed is given (Gymnasium convention).
        super().reset(seed=seed)
        self.done = False
        self.snake_body = None
        self.body_positions_set = None
        self.img.fill(0)
        self.remaining_moves[0] = self.max_moves
        self.snake_position = self.custom_snake.copy() if self.custom_snake is not None else self.random_start_position()
        self.apple_position = self.collision_with_apple(0, 0)[0]

        self.img[self.apple_position[1], self.apple_position[0]] = 10
        self.draw_snake(self.img, self.snake_position)
        self.snake_head = np.copy(self.snake_position[0])
        # Best per-axis head-to-apple offsets seen so far (used for shaping).
        self.mins_neg = self.snake_head - self.apple_position
        self.mins = np.abs(self.mins_neg)

        self.score = 0
        self.reward = 0
        if self.should_render:
            self.render()
        self.observation = self.flatten_obs(self.img, self.mins_neg, self.remaining_moves, self.safe_moves_array(self.snake_head, self.snake_position))
        info = {}
        return (
            self.observation,
            info,
        )

    def flatten_obs(self, image, info, remaining_moves, safe_moves):
        """Concatenate the observation parts into one normalized 1-D float32 vector.

        Args:
            image: Board codes (0-10); flattened and divided by 10.
            info: Signed (dx, dy) offsets; shifted by ``board_size`` and divided
                by ``2 * board_size``.
            remaining_moves: One-element move counter; divided by ``max_moves``.
            safe_moves: Four 0/1 flags, passed through.
        """
        image = np.float32(image.flatten()) / 10
        info = (np.float32(info) + self.board_size) / (self.board_size * 2)
        remaining_moves = np.float32(remaining_moves) / self.max_moves
        safe_moves = np.float32(safe_moves)
        return np.hstack((image, info, remaining_moves, safe_moves))

    def draw_snake(self, img, snake_position, done=False):
        """Write the snake's segment codes into ``img`` (indexed ``[y, x]``).

        Codes: 1 head, 2-5 body (by direction to the next segment), 6-9 tail.
        ``self.snake_body`` is updated incrementally. The update assumes the
        snake either moved one cell or grew by one since the previous call. It
        is rebuilt from scratch when ``self.snake_body`` is None.

        Args:
            img: 2-D board image, modified in place.
            snake_position: Snake cells as ``(x, y)`` rows, head first.
            done: If True, the head cell is not drawn (it may be off the board).
        """
        if self.snake_body is None:
            self.snake_body = self.get_snake_body(snake_position)
        else:
            if len(snake_position) != len(self.snake_body):
                # Grew: prepend a head code; the tail code is unchanged.
                self.snake_body = np.hstack((1, self.snake_body))
            else:
                # Moved: shift codes back one slot; the new last entry was a
                # body code, so +4 converts it to the matching tail code.
                self.snake_body[1:] = self.snake_body[:-1]
                self.snake_body[0] = 1
                self.snake_body[-1] += 4
            # Recompute the code for the segment right behind the head.
            sub = snake_position[1] - snake_position[0]
            total = sub[0] * 2 + sub[1] * 3
            if total < 0:
                total += 7
            self.snake_body[1] = total

        snake_body = self.snake_body
        if done:
            img[snake_position[1:, 1], snake_position[1:, 0]] = snake_body[1:]
        else:
            img[snake_position[:, 1], snake_position[:, 0]] = snake_body

    def get_snake_body(self, snake_position):
        """Compute the per-segment codes for a snake from scratch.

        Each step from a segment to the one behind it is encoded as
        ``2*dx + 3*dy`` (negative values shifted by +7, giving 2-5). The tail
        gets +4 (6-9) and the head is 1.

        Returns:
            1-D integer array with one code per segment, head first.
        """
        snake_body = (snake_position[1:] - snake_position[:-1])
        snake_body = np.matmul(snake_body, self.norm_mul)
        snake_body[snake_body < 0] += 7
        snake_body[-1] += 4
        snake_body = np.vstack((1, snake_body))
        snake_body = snake_body.reshape(-1)
        return snake_body

    def render(self, mode="human"):
        """Draw the board with matplotlib.

        The snake is shown in green/blue, with blue growing from head to tail;
        the apple is shown in red. The figure is created on first use if the
        env was built with ``render=False``.
        """
        if getattr(self, "ax", None) is None:
            self._init_render()
        self.ax.clear()
        cp = np.zeros((self.board_size, self.board_size, 3), dtype=np.uint8)
        snake_colors = np.arange(0, 256, 255 / (len(self.snake_position) - 1), dtype=np.uint8)
        if not self.done:
            cp[self.snake_position[:, 1], self.snake_position[:, 0], 1] = 128
            cp[self.snake_position[:, 1], self.snake_position[:, 0], 2] = snake_colors
        else:
            # The head may be off the board after a collision; skip it.
            cp[self.snake_position[1:, 1], self.snake_position[1:, 0], 1] = 128
            cp[self.snake_position[1:, 1], self.snake_position[1:, 0], 2] = snake_colors[1:]
        cp[self.apple_position[1], self.apple_position[0], 0] = self.full_color
        self.ax.imshow(cp)
        self.fig.canvas.draw()
        plt.pause(0.01)

    def collision_with_apple(self, apple_position, score):
        """Place a new apple uniformly at random on a cell the snake does not occupy.

        Uses rejection sampling, so the caller must ensure a free cell exists
        (``step`` skips this call when the board is full).

        Args:
            apple_position: Ignored; kept for call compatibility.
            score: Current score.

        Returns:
            ``(new_apple_position, score + 1)``; the position is an int32
            ``(x, y)`` array.
        """
        body_positions_set = self.get_body_positions_set(self.snake_position)
        new_apple = self.np_random.integers(0, self.board_size, size=2, dtype=np.int32)
        while tuple(new_apple) in body_positions_set:
            new_apple = self.np_random.integers(0, self.board_size, size=2, dtype=np.int32)
        return new_apple, score + 1

    def collision_with_boundaries(self, snake_head):
        """Return 1 if ``snake_head`` (x, y) lies outside the board, else 0."""
        x, y = snake_head[0], snake_head[1]
        inside = 0 <= x < self.board_size and 0 <= y < self.board_size
        return 0 if inside else 1

    def collision_with_self(self, snake_position, snake_head):
        """Return 1 if two segments share a cell, else 0.

        ``snake_head`` is unused; the check compares the size of the tracked
        occupied-cell set against the snake length.
        """
        body_positions_set = self.get_body_positions_set(snake_position)
        if len(body_positions_set) != len(snake_position):
            return 1
        else:
            return 0

    def get_body_positions_set(self, snake_position):
        """Return the set of occupied ``(x, y)`` tuples, updated incrementally.

        On the first call after a reset, the set is built from every segment.
        Later calls assume at most one move or growth since the previous call:
        - if the set size differs from the snake length, the snake grew, so the
          head is added;
        - if the tail moved, the old tail is removed and the head is added.
        A head that lands on an occupied cell leaves the set smaller than the
        snake, which ``collision_with_self`` detects.
        """
        if self.body_positions_set is None:
            self.body_positions_set = set(tuple(pos) for pos in snake_position)
            self.last_tail_tuple = tuple(snake_position[-1])
        elif len(self.body_positions_set) != len(snake_position):
            self.body_positions_set.add(tuple(snake_position[0]))
        elif self.last_tail_tuple != tuple(snake_position[-1]):
            # Remove before adding so moving into the vacated tail cell is legal.
            self.body_positions_set.remove(self.last_tail_tuple)
            self.body_positions_set.add(tuple(snake_position[0]))
            self.last_tail_tuple = tuple(snake_position[-1])

        return self.body_positions_set

    def safe_moves_array(self, snake_head, snake_position):
        """Return 0/1 safety flags for the four cells next to the head.

        The flag order is (x-1, x+1, y-1, y+1). A neighbour is unsafe if it is
        off the board or occupied. The tail cell counts as safe because the tail
        moves away on the same step.

        NOTE(review): action indices 2 and 3 move by y+1 and y-1, so flags 2/3
        are swapped relative to the actions. This is harmless as a learned
        feature, but do not use this array directly as an action mask without
        reordering.
        """
        body_positions_set = self.get_body_positions_set(snake_position)
        snake_head_x = snake_head[0]
        snake_head_y = snake_head[1]
        tail_tuple = tuple(snake_position[-1])
        safe_moves = np.ones((4,), dtype=np.uint8)
        move_left_tuple = (snake_head_x - 1, snake_head_y)
        move_right_tuple = (snake_head_x + 1, snake_head_y)
        move_up_tuple = (snake_head_x, snake_head_y - 1)
        move_down_tuple = (snake_head_x, snake_head_y + 1)
        if (move_left_tuple in body_positions_set or move_left_tuple[0] < 0) and move_left_tuple != tail_tuple:
            safe_moves[0] = 0
        if (move_right_tuple in body_positions_set or move_right_tuple[0] >= self.board_size) and move_right_tuple != tail_tuple:
            safe_moves[1] = 0
        if (move_up_tuple in body_positions_set or move_up_tuple[1] < 0) and move_up_tuple != tail_tuple:
            safe_moves[2] = 0
        if (move_down_tuple in body_positions_set or move_down_tuple[1] >= self.board_size) and move_down_tuple != tail_tuple:
            safe_moves[3] = 0
        return safe_moves

    def random_start_position(self):
        """Return a random 3-cell starting snake taken from the Hamiltonian path.

        A random window of 3 consecutive path cells is chosen, in either
        direction. Consecutive path cells are adjacent, so the snake is always
        contiguous. Uses ``self.np_random``, so ``reset(seed=...)`` makes it
        reproducible.
        """
        start = int(self.np_random.integers(0, self.max_len - 2))
        snake = self.ham_arr[start:start + 3].copy().astype(np.int32)
        if self.np_random.integers(2) == 0:
            snake = np.flip(snake, axis=0)
        return snake

    def gen_ham_array(self):
        """Return a Hamiltonian path over the board, using an on-disk cache.

        The path is loaded from ``HAM_CACHE_DIR/{n}x{n}.npy`` when that file
        holds a valid path for this board size. Otherwise the path is rebuilt,
        and the cache write is best-effort: a failure only emits a warning.

        Returns:
            uint8 array of shape ``(board_size**2, 2)`` of ``(x, y)`` cells,
            each orthogonally adjacent to the next.
        """
        n = self.board_size
        path_str = os.path.join(self.HAM_CACHE_DIR, f"{n}x{n}.npy")
        if os.path.exists(path_str):
            try:
                ham_arr = np.load(path_str)
            except (OSError, ValueError, EOFError):
                ham_arr = None
            # A stale/corrupt cache (e.g. from an older, buggy generator) is rebuilt.
            if ham_arr is not None and self._is_valid_ham_path(ham_arr, n):
                return ham_arr
        ham_arr = self._build_ham_path(n)
        try:
            os.makedirs(self.HAM_CACHE_DIR, exist_ok=True)
            np.save(path_str, ham_arr)
        except OSError as err:
            import warnings
            warnings.warn(f"could not cache Hamiltonian path to {path_str}: {err}")
        return ham_arr

    @staticmethod
    def _build_ham_path(n):
        """Build a Hamiltonian path over an ``n`` x ``n`` grid as ``(x, y)`` rows.

        Even ``n``: zigzag down/up each column over rows 1..n-1, then walk row 0
        back from x=n-1 to x=0. The final cell (0, 0) is adjacent to the first,
        (0, 1), so the path closes into a cycle. Odd ``n`` (where no such cycle
        exists): plain column zigzag over all rows.
        Integer arithmetic is done in Python ints to avoid unsigned wraparound.
        """
        cells = []
        if n % 2 == 0:
            for x in range(n):
                rows = range(1, n) if x % 2 == 0 else range(n - 1, 0, -1)
                cells.extend((x, y) for y in rows)
            cells.extend((x, 0) for x in range(n - 1, -1, -1))
        else:
            for x in range(n):
                rows = range(n) if x % 2 == 0 else range(n - 1, -1, -1)
                cells.extend((x, y) for y in rows)
        return np.array(cells, dtype=np.uint8)

    @staticmethod
    def _is_valid_ham_path(path, n):
        """Return True if ``path`` visits each cell of an ``n`` x ``n`` grid exactly
        once and every step moves to an orthogonally adjacent cell."""
        path = np.asarray(path)
        if path.shape != (n * n, 2) or path.size == 0:
            return False
        # Widen first so differences of unsigned cells cannot wrap around.
        cells = path.astype(np.int64)
        if cells.min() < 0 or cells.max() >= n:
            return False
        if len({(int(x), int(y)) for x, y in cells}) != n * n:
            return False
        steps = np.abs(np.diff(cells, axis=0)).sum(axis=1)
        return bool(np.all(steps == 1))

